import os
import numpy as np
import torch
from datasets import load_dataset
import random
import io
import json
import sys
import pandas as pd
from typing import List, Dict, Optional
from tqdm import tqdm



def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs so that runs are repeatable.

    Args:
        seed: Integer seed passed to every random number generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    # Ask cuDNN to pick deterministic kernels.
    torch.backends.cudnn.deterministic = True



def sample_train_loaders(name, tokenizer, nsamples=128, seed=0, seqlen=2048):
    """Draw random tokenized windows from the WikiText-2 or C4 training text.

    Args:
        name: Dataset selector; must contain ``"wikitext2"`` or ``"c4"``.
        tokenizer: Callable tokenizer returning an object with ``input_ids``.
        nsamples: Number of windows to draw.
        seed: Seed forwarded to ``set_seed`` before sampling.
        seqlen: Maximum number of tokens kept per window.

    Returns:
        A list of ``nsamples`` ``input_ids`` tensors, each cut to ``seqlen``
        tokens along the last dimension.

    Raises:
        NotImplementedError: If ``name`` matches neither dataset.
    """
    set_seed(seed)

    if "wikitext2" in name:
        corpus = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    elif "c4" in name:
        corpus = load_dataset(
            "allenai/c4",
            "allenai--c4",
            data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
            split="train",
        )
    else:
        raise NotImplementedError
    full_text = "\n\n".join(corpus["text"])

    # Each window is a slice of 2 * seqlen characters, tokenized and then
    # truncated to seqlen tokens.
    window_chars = seqlen * 2
    last_start = len(full_text) - window_chars - 1
    windows = []
    for _ in range(nsamples):
        start = random.randint(0, last_start)
        encoded = tokenizer(full_text[start:start + window_chars], return_tensors="pt")
        windows.append(encoded.input_ids[:, :seqlen])
    return windows


def get_redpajama_train(tokenizer, percent=10, seed=3, batch_size=128, max_length=2048):
    """Load a slice of the RedPajama sample set and tokenize its ``text`` column.

    Args:
        tokenizer: Callable tokenizer applied to batches of text.
        percent: Share of the data to load; 100 loads the full ``train``
            split, anything else loads the first ``850000 * percent / 100``
            rows.
        seed: Accepted for interface compatibility; not used here.
        batch_size: Batch size handed to ``Dataset.map``.
        max_length: Truncation length handed to the tokenizer.

    Returns:
        The dataset returned by ``Dataset.map`` with the tokenizer output added.
    """
    split = "train" if percent == 100 else f"train[:{int(850000*percent/100)}]"
    raw_dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", split=split)

    def _tokenize_batch(batch):
        return tokenizer(batch["text"], truncation=True, max_length=max_length)

    return raw_dataset.map(
        _tokenize_batch,
        batched=True,
        batch_size=batch_size,
        num_proc=os.cpu_count(),
    )


def get_english_quote(dataset_name, tokenizer):
    """Load ``dataset_name``, tokenize its ``quote`` column and return the train split.

    Args:
        dataset_name: Identifier passed to ``load_dataset``.
        tokenizer: Callable tokenizer applied to batches of quotes.

    Returns:
        The ``"train"`` split of the tokenized dataset.
    """
    def _tokenize_quotes(samples):
        return tokenizer(samples["quote"])

    tokenized = load_dataset(dataset_name).map(_tokenize_quotes, batched=True)
    return tokenized["train"]


def get_qat_dataset(name, tokenizer, data_percent):
    """Return a shuffled, tokenized training dataset selected by ``name``.

    Args:
        name: ``"red_pajama"`` or ``"Abirate/english_quotes"``.
        tokenizer: Tokenizer forwarded to the dataset-specific loader.
        data_percent: Percentage forwarded to ``get_redpajama_train``;
            ignored for the quotes dataset.

    Returns:
        The result of calling ``shuffle()`` on the loaded dataset.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported datasets.
    """
    if name not in ("red_pajama", "Abirate/english_quotes"):
        raise NotImplementedError

    if name == "red_pajama":
        dataset = get_redpajama_train(tokenizer, data_percent)
    else:
        dataset = get_english_quote(name, tokenizer)
    return dataset.shuffle()

# Disabled variant of the chat template, kept as a bare string expression so it
# has no runtime effect. It differs from the active template only in using
# doubled braces around the placeholders.
'''
llama_chat_format="""<s>[INST] <<SYS>>
"Below is an instruction that describes a task. Write a response that appropriately completes the request."
<</SYS>>

{{ instruction }} [/INST] {{ response }} </s>
"""
'''

# Llama-2 style chat prompt template. It is filled in with
# ``llama_chat_format.format(instruction=..., response=...)`` when building
# instruction-tuning calibration samples.
llama_chat_format="""<s>[INST] <<SYS>>
"Below is an instruction that describes a task. Write a response that appropriately completes the request."
<</SYS>>

{instruction} [/INST] {response} </s>
"""


def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
        #f = open(f)
    return f

def jload(f, mode="r"):
    """Load a .json file into a dictionary.

    Args:
        f: Path to a JSON file, or an already-open file object.
        mode: Mode used when ``f`` is a path and has to be opened.

    Returns:
        The object produced by ``json.load``.

    Note:
        The file object is always closed before returning, including when
        ``f`` was passed in already open and when parsing fails.
    """
    f = _make_r_io_base(f, mode)
    try:
        return json.load(f)
    finally:
        # Close even when json.load raises so the handle is not leaked.
        f.close()


def load_infoseek_data(question_file: str, nsamples: int, seed: int = 3, save_sampled_data: bool = True, output_file: str = "sampled_infoseek_data.jsonl", max_tokens: int = 2048, tokenizer=None):
    """Load, sample and validate records from an InfoSeek JSONL file.

    Each kept record has non-empty text (the ``section_texts`` entries joined
    with spaces) and a ``local_image_path`` that exists on disk and can be
    opened by PIL. When a tokenizer is given, text longer than ``max_tokens``
    tokens is truncated by decoding the first ``max_tokens`` token ids.

    Args:
        question_file: Path to the JSONL file (one JSON object per line).
        nsamples: Number of records to sample; all records are used when the
            file has no more than this many.
        seed: Random seed forwarded to ``set_seed``.
        save_sampled_data: Whether to write the kept records to ``output_file``.
        output_file: Destination JSONL path for the kept records.
        max_tokens: Maximum number of tokens kept per text.
        tokenizer: Optional tokenizer used for the token-level truncation.
            When ``None`` the text is kept as-is.

    Returns:
        A list of dicts with keys ``image``, ``text``, ``original_index`` and
        ``image_size``.
    """
    set_seed(seed)

    # Read the JSONL file, ignoring blank lines.
    data_list = []
    with open(question_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                data_list.append(json.loads(line))

    print(f"总共读取到 {len(data_list)} 条数据")

    # Random sampling without replacement.
    if nsamples < len(data_list):
        sampled_data = random.sample(data_list, nsamples)
    else:
        sampled_data = data_list
        print(f"⚠ 请求的样本数 {nsamples} 大于数据集大小 {len(data_list)}，使用所有数据")

    if tokenizer is None:
        # Calling a missing tokenizer would fail on every record and silently
        # produce an empty result, so token-level truncation is skipped instead.
        print("⚠ tokenizer is None, skipping token-level truncation")

    processed_data = []
    sampled_data_for_save = []  # Richer records written to the JSONL output.

    for i, item in enumerate(tqdm(sampled_data, desc="处理infoseek数据")):
        try:
            # Join the non-empty section texts with single spaces.
            section_texts = item.get('section_texts', [])
            if isinstance(section_texts, list):
                text_parts = [
                    part.strip()
                    for part in section_texts
                    if isinstance(part, str) and part.strip()
                ]
                text = ' '.join(text_parts)
            else:
                text = str(section_texts) if section_texts else ""

            if not text:
                print(f"⚠ 第 {i} 条数据文本为空，跳过")
                continue

            local_image_path = item.get('local_image_path', '')
            if not local_image_path:
                print(f"⚠ 第 {i} 条数据没有本地图像路径，跳过")
                continue

            if not os.path.exists(local_image_path):
                print(f"⚠ 图像文件不存在: {local_image_path}")
                continue

            # Check that the image can be decoded and record its size.
            try:
                with Image.open(local_image_path) as img:
                    img.verify()
                # An image object must not be reused after verify(), so the
                # file is reopened (and closed again) to read the size.
                with Image.open(local_image_path) as img:
                    img_size = img.size
            except Exception as e:
                print(f"⚠ 图像文件损坏或无法打开: {local_image_path}, 错误: {e}")
                continue

            if tokenizer is not None:
                # Encode once to learn the token count, then truncate if needed.
                temp_enc = tokenizer(text, return_tensors="pt")
                original_token_count = temp_enc.input_ids.shape[1]

                if original_token_count > max_tokens:
                    print(f"⚠ 第 {i} 条数据文本过长（{original_token_count}个token），进行截断")
                    print(f"  原始文本长度: {len(text)} 字符")

                    # Keep the first max_tokens ids and decode them back to text.
                    truncated_tokens = temp_enc.input_ids[:, :max_tokens]
                    text = tokenizer.decode(truncated_tokens[0], skip_special_tokens=True)

                    print(f"  截断后文本长度: {len(text)} 字符")
                    print(f"  截断掉的token数量: {original_token_count - max_tokens}")

                    # Re-encode to report the token count after truncation.
                    verify_enc = tokenizer(text, return_tensors="pt")
                    print(f"  验证截断后token数量: {verify_enc.input_ids.shape[1]}")
                else:
                    print(f"✓ 第 {i} 条数据文本长度正常（{original_token_count}个token）")

            processed_data.append({
                'image': local_image_path,
                'text': text,  # Possibly truncated text.
                'original_index': i,
                'image_size': img_size,
            })

            # Record written to disk; it carries more of the original fields.
            sampled_data_for_save.append({
                'wikipedia_url': item.get('wikipedia_url', ''),
                'question': item.get('question', ''),
                'answer': item.get('answer', ''),
                'image_urls': item.get('image_urls', ''),
                'section_texts': item.get('section_texts', []),
                'local_image_path': local_image_path,
                'processed_text': text,  # Possibly truncated text.
                'original_index': i,
                'sampled_index': len(processed_data) - 1,
                'image_size': img_size,
            })

        except Exception as e:
            # Best effort: a malformed record is reported and skipped.
            print(f"⚠ 处理第 {i} 条数据时出错: {e}")
            continue

    print(f"✓ 成功处理 {len(processed_data)} 条数据")

    if save_sampled_data:
        try:
            # Create the output directory if needed; a bare filename maps to ".".
            output_dir = os.path.dirname(output_file) or "."
            os.makedirs(output_dir, exist_ok=True)

            with open(output_file, 'w', encoding='utf-8') as f:
                for sample in sampled_data_for_save:
                    f.write(json.dumps(sample, ensure_ascii=False) + '\n')

            print(f"✓ 采样的infoseek数据已保存到: {output_file}")
            print(f"  - 包含 {len(sampled_data_for_save)} 个样本")

        except Exception as e:
            # Saving is best effort; the processed data is still returned.
            print(f"⚠ 保存数据时出错: {e}")

    return processed_data






def _read_jsonl_dicts(path):
    """Read a JSON Lines file and return its records as a list of dicts."""
    with open(path, 'r') as json_file:
        json_list = list(json_file)
    print(len(json_list))
    records = []
    for raw_line in json_list:
        record = json.loads(raw_line)
        if not isinstance(record, dict):
            raise TypeError(f"expected a JSON object per line in {path}")
        records.append(record)
    return records


def _encode_chat_examples(examples, tokenizer, seqlen, instruction_of, response_of):
    """Format examples with ``llama_chat_format`` and tokenize them.

    Only examples whose ``"input"`` field is missing or empty are used.
    """
    samples = []
    last_prompt = None
    for example in examples:
        if example.get("input", "") != "":
            continue
        last_prompt = llama_chat_format.format(
            instruction=instruction_of(example), response=response_of(example)
        )
        inp = tokenizer(last_prompt, return_tensors="pt").input_ids[:, :seqlen]
        samples.append({"input_ids": inp, "attention_mask": torch.ones_like(inp)})
    # Guarded so an empty selection does not fail on an undefined prompt.
    if last_prompt is not None:
        print("example instruction:", last_prompt)
    return samples


def get_calib_data(name, tokenizer, model_id, nsamples, seqlen=2048, seed=3, question_file=None):
    """Build, or load from ``cache/``, a list of tokenized calibration samples.

    Args:
        name: Dataset name: ``"infoseek"``, ``"c4"``, ``"wikitext2"``,
            ``"ptb"``, ``"traivia_qa"`` (``"trivia_qa"`` also accepted),
            ``"nqopen"``, ``"alpaca"``, ``"MetaMATH"``, ``"codefeedback"`` or
            ``"WizLMinstruct"``.
        tokenizer: Callable tokenizer returning an object with ``input_ids``.
        model_id: Model identifier; used only to build the cache file name.
        nsamples: Number of samples to draw.
        seqlen: Maximum number of tokens kept per sample.
        seed: Seed for Python's ``random`` module.
        question_file: JSONL path, required when ``name == "infoseek"``.

    Returns:
        A list of dicts with ``input_ids`` and ``attention_mask`` tensors;
        infoseek samples additionally carry ``image_path``.

    Raises:
        ValueError: If ``name == "infoseek"`` and ``question_file`` is None.
        NotImplementedError: If ``name`` is not a supported dataset.
    """
    print(f" get_data_from: {name}, nsamples={nsamples}, seqlen={seqlen}, {seed}")
    # NOTE: the cache key does not include question_file.
    cache_file = (
        f"cache/{name}_{model_id.replace('/','_')}_{nsamples}_{seqlen}_{seed}.pt"
    )
    random.seed(seed)
    os.makedirs("cache", exist_ok=True)
    if os.path.exists(cache_file):
        print(f"found data file: {cache_file}")
        # NOTE(review): torch.load unpickles the file; only load trusted caches.
        traindataset = torch.load(cache_file)
        print("loaded ...")
        return traindataset

    # InfoSeek: text paired with a local image path.
    if name == "infoseek":
        if question_file is None:
            raise ValueError("infoseek数据集需要提供question_file参数")

        print(f"加载infoseek数据集: {question_file}")
        # The tokenizer is forwarded so the loader can run its token-level
        # truncation instead of being handed None.
        infoseek_data = load_infoseek_data(question_file, nsamples, seed, tokenizer=tokenizer)

        traindataset = []
        for item in infoseek_data:
            trainenc = tokenizer(item['text'], return_tensors="pt")

            print(f"Original input_ids shape: {trainenc.input_ids.shape}")
            print(f"Original input_ids length: {trainenc.input_ids.shape[1]}")
            print(f"seqlen: {seqlen}")

            # Crop to at most seqlen tokens.
            inp = trainenc.input_ids[:, :seqlen]

            print(f"Cropped input_ids shape: {inp.shape}")
            print(f"Cropped input_ids length: {inp.shape[1]}")

            if trainenc.input_ids.shape[1] > seqlen:
                print("✅ 发生了裁剪操作")
                print(f"裁剪掉的token数量: {trainenc.input_ids.shape[1] - seqlen}")
            else:
                print("❌ 没有发生裁剪操作")

            print("-" * 50)

            attention_mask = torch.ones_like(inp)
            traindataset.append({
                "input_ids": inp,
                "attention_mask": attention_mask,
                "image_path": item['image'],  # Kept for later image loading.
            })

        torch.save(traindataset, cache_file)
        return traindataset

    # Instruction datasets for chat models:
    # (file path, loader, instruction getter, response getter).
    chat_sources = {
        "alpaca": (
            "data/alpaca_data.json", jload,
            lambda ex: ex["instruction"], lambda ex: ex["output"],
        ),
        "MetaMATH": (
            "data/MetaMathQA-395K.json", jload,
            lambda ex: ex["query"], lambda ex: ex["response"],
        ),
        "codefeedback": (
            "data/CodeFeedback-Filtered-Instruction.jsonl", _read_jsonl_dicts,
            lambda ex: ex["query"], lambda ex: ex["answer"],
        ),
        "WizLMinstruct": (
            "data/WizardLM_evol_instruct_V2_143k.jsonl", _read_jsonl_dicts,
            lambda ex: ex["conversation"][0]["human"],
            lambda ex: ex["conversation"][0]["assistant"],
        ),
    }
    if name in chat_sources:
        data_path, load_records, instruction_of, response_of = chat_sources[name]
        list_data_dict = load_records(data_path)
        # Every chat dataset samples through ``random``, which was seeded
        # above, so the cached result matches the seed in the file name.
        selected_data_dict = random.sample(list_data_dict, nsamples)
        traindataset = _encode_chat_examples(
            selected_data_dict, tokenizer, seqlen, instruction_of, response_of
        )
        torch.save(traindataset, cache_file)
        return traindataset

    # Plain-text corpora: join everything, then sample character windows.
    if name == "c4":
        traindata = load_dataset(
            "allenai/c4",
            "allenai--c4",
            data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
            split="train",
        )
        tot_text = "\n\n".join(traindata["text"])
    elif name == "wikitext2":
        traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
        tot_text = "\n\n".join(traindata["text"])
    elif name == "ptb":
        traindata = load_dataset(
            "ptb_text_only",
            "penn_treebank",
            split="train",
        )
        tot_text = "\n\n".join(traindata["sentence"])
    elif name in ("traivia_qa", "trivia_qa"):
        # The misspelled name is kept for backward compatibility.
        traindata = load_dataset("trivia_qa", "rc", split="train")
        tot_text = "\n\n".join(traindata["question"])
    elif name == "nqopen":
        traindata = load_dataset("nq_open", split="train")
        tot_text = "\n\n".join(traindata["question"])
    else:
        raise NotImplementedError

    print(f"tot_text={len(tot_text)}")
    # Each window spans seqlen * 10 characters. The start index is bounded by
    # that same width so a window never runs past the end of the corpus.
    window_chars = seqlen * 10
    max_start = max(len(tot_text) - window_chars - 1, 0)
    traindataset = []
    for _ in range(nsamples):
        i = random.randint(0, max_start)
        j = i + window_chars
        trainenc = tokenizer(tot_text[i:j], return_tensors="pt")
        inp = trainenc.input_ids[:, :seqlen]
        attention_mask = torch.ones_like(inp)
        traindataset.append({"input_ids": inp, "attention_mask": attention_mask})
    torch.save(traindataset, cache_file)
    return traindataset


def get_eval_loaders(name, tokenizer):
    """Tokenize the evaluation split of WikiText-2, PTB or C4 as one sequence.

    Args:
        name: Dataset selector; checked in order for the substrings
            ``"wikitext2"``, ``"ptb"`` and ``"c4"``.
        tokenizer: Callable tokenizer applied to the joined text.

    Returns:
        The tokenizer output for all rows joined with blank lines.

    Raises:
        NotImplementedError: If ``name`` matches none of the datasets.
    """
    if "wikitext2" in name:
        column = "text"
        eval_split = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    elif "ptb" in name:
        column = "sentence"
        eval_split = load_dataset("ptb_text_only", "penn_treebank", split="validation")
    elif "c4" in name:
        column = "text"
        eval_split = load_dataset(
            "allenai/c4",
            "allenai--c4",
            data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
            split="validation",
        )
    else:
        raise NotImplementedError

    joined_text = "\n\n".join(eval_split[column])
    return tokenizer(joined_text, return_tensors="pt")






########################################LLaVA load benchmark data #####################################################

# 添加LLaVA路径
sys.path.append("/home/bingxing2/ailab/scx6mh7/jkl/LLaVA_8_8_null_space")
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import math

# 导入benchmark_load.py中的功能
from benchmark_load import process_vlmeval_datasets, DATASET_CONFIG




def split_list(lst, n):
    """Split a list into at most ``n`` consecutive, roughly equal-sized chunks.

    Args:
        lst: The list to split.
        n: Desired number of chunks; must be at least 1.

    Returns:
        A list of slices of ``lst``, each of length ``ceil(len(lst) / n)``
        except possibly the last. Fewer than ``n`` chunks are returned when
        ``lst`` is too short, and an empty list when ``lst`` is empty.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be at least 1")
    if not lst:
        # A chunk size of 0 would make range() fail on a zero step.
        return []
    chunk_size = math.ceil(len(lst) / n)  # ceiling division
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

def get_chunk(lst, n, k):
    """Return chunk number ``k`` (0-based) of ``lst`` split by ``split_list(lst, n)``."""
    return split_list(lst, n)[k]

def _benchmark_record(dataset_name, row, image_paths):
    """Build the fields shared by sampled and complete benchmark records."""
    raw_index = row['index']
    # Keep the benchmark index as an int when possible, else as a string.
    try:
        original_index = int(raw_index)
    except (ValueError, TypeError):
        original_index = str(raw_index)
    return {
        'dataset_name': dataset_name,
        'image': image_paths[str(raw_index)],
        'text': row.get('question', ''),
        'original_index': original_index,
    }


def load_benchmark_datasets(dataset_names: List[str], n_samples_per_dataset: int,
                          data_root: str = "/home/jiangkailin/mydisk/iclr26_evoke_dynamic_null_space/cache/vlmeval",
                          seed: int = 233, save_sampled_data: bool = True,
                          output_file: str = "sampled_benchmark_data.jsonl") -> List[Dict]:
    """Load the given benchmark datasets and sample image/question pairs.

    A row is valid when its ``question`` is non-empty and ``str(row['index'])``
    is a key of the dataset's image path mapping. Sampling is done per dataset
    among valid rows only.

    Args:
        dataset_names: Dataset names, e.g. ``['MME', 'OCRBench']``.
        n_samples_per_dataset: Number of samples to draw from each dataset.
        data_root: Root directory passed to ``process_vlmeval_datasets``.
        seed: Random seed forwarded to ``set_seed``.
        save_sampled_data: Whether to write the results to JSONL files.
        output_file: Path for the sampled records. The complete list of valid
            records, each flagged with ``is_sampled``, is written next to it
            with ``_original_complete`` inserted before the extension.

    Returns:
        The sampled records of all datasets, each a dict with keys
        ``dataset_name``, ``image``, ``text``, ``original_index`` and
        ``sampled_index``.
    """
    set_seed(seed)
    print(f"使用随机种子: {seed}")

    print(f"正在加载数据集: {dataset_names}")
    print(f"每个数据集采样 {n_samples_per_dataset} 个样本")

    dataset_results = process_vlmeval_datasets(dataset_names, data_root)

    all_sampled_data = []
    all_original_data = []  # Every valid record, sampled or not.

    for dataset_name, dataset_result in dataset_results.items():
        if dataset_result is None:
            print(f"⚠ 数据集 {dataset_name} 处理失败，跳过")
            continue

        print(f"处理数据集: {dataset_name}")
        print(f"  总样本数: {dataset_result['total_samples']}")
        print(f"  图像数: {dataset_result['image_count']}")

        data = dataset_result['data']
        image_paths = dataset_result['image_paths']

        # Collect the index labels of rows with both a question and an image.
        valid_indices = []
        for idx, row in data.iterrows():
            question = row.get('question', '')
            has_question = not pd.isna(question) and str(question).strip() != ''
            has_image = str(row['index']) in image_paths
            if has_question and has_image:
                valid_indices.append(idx)

        print(f"  有效样本数（同时包含图像和问题）: {len(valid_indices)}")

        if not valid_indices:
            print(f"⚠ 数据集 {dataset_name} 没有有效样本（同时包含图像和问题），跳过")
            continue

        if n_samples_per_dataset < len(valid_indices):
            sampled_valid_indices = random.sample(valid_indices, n_samples_per_dataset)
        else:
            sampled_valid_indices = valid_indices
            print(f"⚠ 请求的样本数 {n_samples_per_dataset} 大于数据集 {dataset_name} 的有效样本数 {len(valid_indices)}，使用所有有效样本")

        # iterrows() yields index labels, so rows are fetched with .loc;
        # .iloc would select the wrong rows on a non-default index.
        dataset_sampled_data = []
        for idx in sampled_valid_indices:
            sample = _benchmark_record(dataset_name, data.loc[idx], image_paths)
            sample['sampled_index'] = len(dataset_sampled_data)
            dataset_sampled_data.append(sample)
            all_sampled_data.append(sample)

        # A set keeps the is_sampled lookup O(1) per row.
        sampled_lookup = set(sampled_valid_indices)
        for idx in valid_indices:
            original_sample = _benchmark_record(dataset_name, data.loc[idx], image_paths)
            original_sample['is_sampled'] = idx in sampled_lookup
            all_original_data.append(original_sample)

        print(f"  ✓ 从数据集 {dataset_name} 采样了 {len(dataset_sampled_data)} 个样本")

    print(f"✓ 总共采样了 {len(all_sampled_data)} 个样本")

    if save_sampled_data:
        try:
            with open(output_file, 'w', encoding='utf-8') as f:
                for sample in all_sampled_data:
                    f.write(json.dumps(sample, ensure_ascii=False) + '\n')

            # Derive the second file name from the extension so that it always
            # differs from output_file, whatever the extension is.
            root, ext = os.path.splitext(output_file)
            original_output_file = f"{root}_original_complete{ext or '.jsonl'}"
            with open(original_output_file, 'w', encoding='utf-8') as f:
                for sample in all_original_data:
                    f.write(json.dumps(sample, ensure_ascii=False) + '\n')

            print(f"✓ 采样数据已保存到: {output_file}")
            print(f"✓ 完整原始数据已保存到: {original_output_file}")
            print(f"  - 采样数据包含 {len(all_sampled_data)} 个样本")
            print(f"  - 完整数据包含 {len(all_original_data)} 个样本")

        except Exception as e:
            # Saving is best effort; the sampled data is still returned.
            print(f"⚠ 保存数据时出错: {e}")

    return all_sampled_data












# Custom dataset class
class CustomDataset(Dataset):
    """Dataset producing LLaVA inputs from image/question records.

    Each record is a mapping with an ``image`` path and a ``text`` question.
    ``__getitem__`` returns ``(input_ids, image_tensor, image_size)``.
    """

    def __init__(self, questions, image_folder, tokenizer, image_processor, model_config, conv_mode, output_file=None, n_samples=None):
        """Store the processing objects and optionally subsample the questions.

        Args:
            questions: List of records.
            image_folder: Folder joined with relative image paths.
            tokenizer: Tokenizer passed to ``tokenizer_image_token``.
            image_processor: Processor passed to ``process_images``.
            model_config: Model config; ``mm_use_im_start_end`` is read from it.
            conv_mode: Key into ``conv_templates``.
            output_file: Optional JSONL path; every accessed record is
                appended to it. Nothing is written when it is falsy.
            n_samples: If given and smaller than ``len(questions)``, that many
                questions are drawn at random.
        """
        self.image_folder = image_folder
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.model_config = model_config
        self.conv_mode = conv_mode
        self.output_file = output_file
        # Randomly select n_samples if specified.
        if n_samples is not None and n_samples < len(questions):
            self.questions = random.sample(questions, n_samples)
        else:
            self.questions = questions

    def __getitem__(self, index):
        """Return ``(input_ids, image_tensor, image_size)`` for one record.

        Raises:
            ValueError: If the record has an empty image path or the image
                file does not exist.
        """
        line = self.questions[index]

        # Append the accessed record to the JSONL log (on every access).
        if self.output_file:
            try:
                with open(self.output_file, 'a', encoding='utf-8') as f:
                    json.dump(line, f, ensure_ascii=False)
                    f.write('\n')
            except Exception as e:
                print(f"保存line数据到jsonl文件时出错: {e}")

        if isinstance(line, dict) and 'image' in line:
            # Benchmark-style record.
            image_file = line['image']
            qs = line['text']
            print('----------------使用benchmark_load--------------------------')
            print('image_file:',image_file)
            print('qs:',qs)
            if image_file and os.path.isabs(image_file):
                # Absolute paths are used as-is.
                image_path = image_file
            elif image_file:
                image_path = os.path.join(self.image_folder, image_file)
            else:
                raise ValueError(f"样本 {index} 没有图像文件")
        else:
            # Legacy record format.
            image_file = line["image"]
            qs = line["text"]
            image_path = os.path.join(self.image_folder, image_file)

        if not os.path.exists(image_path):
            raise ValueError(f"图像文件不存在: {image_path}")

        if self.model_config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # convert() returns a separate image, so the file handle can be
        # closed right after it.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        image_tensor = process_images([image], self.image_processor, self.model_config)[0]
        image_size = image.size

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')

        return input_ids, image_tensor, image_size

    def __len__(self):
        """Return the number of (possibly subsampled) questions."""
        return len(self.questions)

def collate_fn(batch):
    """Stack per-sample ids and image tensors; pass image sizes through.

    Args:
        batch: Sequence of ``(input_ids, image_tensor, image_size)`` triples.

    Returns:
        ``(stacked_input_ids, stacked_image_tensors, image_sizes)`` where
        ``image_sizes`` is a tuple with one entry per sample.
    """
    ids_per_sample, images_per_sample, image_sizes = zip(*batch)
    stacked_ids = torch.stack(ids_per_sample, dim=0)
    stacked_images = torch.stack(images_per_sample, dim=0)
    return stacked_ids, stacked_images, image_sizes

def create_data_loader_from_benchmark(dataset_names: List[str], n_samples_per_dataset: int,
                                    tokenizer, image_processor, model_config, conv_mode,
                                    batch_size=1, num_workers=4,
                                    data_root="/home/bingxing2/ailab/scx6mh7/jkl/ckpt_sum/data/vlmeval",
                                    seed=233, save_sampled_data=True, output_file="sampled_benchmark_data.jsonl"):
    """Sample benchmark datasets and wrap the samples in a ``DataLoader``.

    Args:
        dataset_names: Dataset names, e.g. ``['MME', 'OCRBench']``.
        n_samples_per_dataset: Number of samples drawn from each dataset.
        tokenizer: Tokenizer handed to ``CustomDataset``.
        image_processor: Image processor handed to ``CustomDataset``.
        model_config: Model config handed to ``CustomDataset``.
        conv_mode: Conversation template key.
        batch_size: Must be 1.
        num_workers: Number of ``DataLoader`` worker processes.
        data_root: Root directory of the benchmark data.
        seed: Random seed used for sampling.
        save_sampled_data: Whether the sampled records are saved.
        output_file: JSONL path used for saving and handed to ``CustomDataset``.

    Returns:
        A non-shuffling ``DataLoader`` over the sampled records.
    """
    assert batch_size == 1, "batch_size must be 1"

    # Load and sample the benchmark records.
    sampled_questions = load_benchmark_datasets(
        dataset_names,
        n_samples_per_dataset,
        data_root,
        seed,
        save_sampled_data,
        output_file,
    )

    # Records carry their own image paths, so the image folder is empty.
    benchmark_dataset = CustomDataset(
        sampled_questions,
        "",
        tokenizer,
        image_processor,
        model_config,
        conv_mode,
        output_file,
        n_samples=None,
    )

    return DataLoader(
        benchmark_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        collate_fn=collate_fn,
    )





def create_data_loader_from_infoseek(question_file: str, n_samples: int,
                                   tokenizer, image_processor, model_config, conv_mode,
                                   batch_size=1, num_workers=4, seed=233,
                                   save_sampled_data=True, output_file="sampled_infoseek_data.jsonl"):
    """Sample InfoSeek records and wrap them in a ``DataLoader``.

    Args:
        question_file: Path to the InfoSeek JSONL file.
        n_samples: Number of records to sample.
        tokenizer: Tokenizer used for text truncation and by ``CustomDataset``.
        image_processor: Image processor handed to ``CustomDataset``.
        model_config: Model config handed to ``CustomDataset``.
        conv_mode: Conversation template key.
        batch_size: Must be 1.
        num_workers: Number of ``DataLoader`` worker processes.
        seed: Random seed used for sampling.
        save_sampled_data: Whether the sampled records are saved.
        output_file: JSONL path used for saving and handed to ``CustomDataset``.

    Returns:
        A non-shuffling ``DataLoader`` over the sampled records.
    """
    assert batch_size == 1, "batch_size must be 1"

    # The tokenizer is forwarded so over-long texts are cut to 2048 tokens.
    questions = load_infoseek_data(question_file, n_samples, seed, save_sampled_data, output_file, max_tokens=2048, tokenizer=tokenizer)

    # output_file is supplied here as in the benchmark loader; leaving it out
    # raises TypeError when CustomDataset declares it as a required argument.
    dataset = CustomDataset(questions, "", tokenizer, image_processor, model_config, conv_mode, output_file, n_samples=None)

    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                           shuffle=False, collate_fn=collate_fn)

    return data_loader



# DataLoader
def create_data_loader(nsamples, questions, image_folder, tokenizer, image_processor, model_config, conv_mode, output_file,batch_size=1, num_workers=4):
    """Wrap ``questions`` in a ``CustomDataset`` and return a ``DataLoader``.

    Args:
        nsamples: Number of questions the dataset subsamples (``n_samples``).
        questions: List of records.
        image_folder: Folder joined with relative image paths.
        tokenizer: Tokenizer handed to ``CustomDataset``.
        image_processor: Image processor handed to ``CustomDataset``.
        model_config: Model config handed to ``CustomDataset``.
        conv_mode: Conversation template key.
        output_file: JSONL path handed to ``CustomDataset``.
        batch_size: Must be 1.
        num_workers: Number of ``DataLoader`` worker processes.

    Returns:
        A non-shuffling ``DataLoader`` over the dataset.
    """
    assert batch_size == 1, "batch_size must be 1"
    calib_dataset = CustomDataset(
        questions,
        image_folder,
        tokenizer,
        image_processor,
        model_config,
        conv_mode,
        output_file,
        nsamples,
    )
    print('dataset',calib_dataset)
    return DataLoader(
        calib_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        collate_fn=collate_fn,
    )

def _switch_plain_model_to_mmtag(args, model_name):
    """Append ``_mmtag`` to ``args.conv_mode`` for plain, non-finetuned models."""
    if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
        args.conv_mode = args.conv_mode + '_mmtag'
        print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')


def get_calib_data_mllm(args):
    """Load the LLaVA model and build the calibration ``DataLoader``.

    The data source is chosen from ``args``:

    * ``args.dataset_names`` contains ``'infoseek'``: InfoSeek records from
      ``args.question_file``.
    * ``args.dataset_names`` set otherwise: benchmark datasets.
    * ``args.calib_dataset == 'scienceqa'``: a ScienceQA JSON file; only
      entries with an image are kept.
    * otherwise: a JSONL question file.

    Args:
        args: Namespace with ``model_id``, ``model_base``, ``conv_mode`` and
            the fields needed by the selected branch.

    Returns:
        ``(model, data_loader)``.

    Raises:
        ValueError: If InfoSeek is selected without ``args.question_file``.
    """
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_id)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name,
            device_map="auto", torch_dtype=torch.float32, trust_remote_code=True)
    model = model.to(dtype=torch.float32)

    dataset_names = getattr(args, 'dataset_names', None)
    if dataset_names is not None:
        if 'infoseek' in dataset_names:
            print(f"使用infoseek数据集")
            print(f"采样 {args.n_samples_per_dataset} 个样本")

            if getattr(args, 'question_file', None) is None:
                raise ValueError("使用infoseek数据集时必须提供question_file参数")

            print(f"jsonl文件路径: {args.question_file}")

            # Fall back to a seed-tagged name when no output file is given,
            # including when the attribute exists but is None.
            output_file = getattr(args, 'output_file', None) or f"sampled_infoseek_data_seed{args.seed}.jsonl"

            data_loader = create_data_loader_from_infoseek(
                args.question_file,
                args.n_samples_per_dataset,
                tokenizer,
                image_processor,
                model.config,
                args.conv_mode,
                seed=args.seed,
                save_sampled_data=True,
                output_file=output_file
            )
        else:
            print(f"使用benchmark数据集: {args.dataset_names}")
            print(f"每个数据集采样 {args.n_samples_per_dataset} 个样本")

            # A missing output file would otherwise fail on the path handling.
            output_file = getattr(args, 'output_file', None) or "sampled_benchmark_data.jsonl"

            # Create the output directory; a bare filename maps to ".".
            output_dir = os.path.dirname(output_file) or "."
            os.makedirs(output_dir, exist_ok=True)
            print(f"✓ 创建输出目录: {output_dir}")
            print(f"✓ 输出文件路径: {output_file}")

            data_loader = create_data_loader_from_benchmark(
                args.dataset_names,
                args.n_samples_per_dataset,
                tokenizer,
                image_processor,
                model.config,
                args.conv_mode,
                data_root=getattr(args, 'data_root', "/home/jiangkailin/mydisk/iclr26_evoke_dynamic_null_space/cache/vlmeval"),
                seed=args.seed,
                save_sampled_data=True,
                output_file=output_file
            )
    else:
        question_path = os.path.expanduser(args.question_file)
        if getattr(args, 'calib_dataset', None) == 'scienceqa':
            print("使用 ScienceQA 数据加载方式")
            # ScienceQA ships as a single JSON document.
            with open(question_path, "r") as question_handle:
                questions = json.load(question_handle)
            questions = get_chunk(questions, args.num_chunks, args.chunk_idx)

            # Keep only entries with a conversation and a non-empty image.
            processed_questions = []
            for line in questions:
                if 'conversations' in line and len(line['conversations']) > 0:
                    # Question text with the <image> tag removed.
                    qs = line['conversations'][0]['value'].replace('<image>', '').strip()

                    image_file = line.get('image', '')
                    if image_file and image_file.strip():
                        processed_questions.append({
                            'text': qs,
                            'image': image_file
                        })

            print(f"ScienceQA 原始数据: {len(questions)} 条")
            print(f"ScienceQA 有效数据（有图像）: {len(processed_questions)} 条")

            questions = processed_questions
        else:
            # Legacy format: one JSON object per line.
            with open(question_path, "r") as question_handle:
                questions = [json.loads(q) for q in question_handle]
            questions = get_chunk(questions, args.num_chunks, args.chunk_idx)

        _switch_plain_model_to_mmtag(args, model_name)
        data_loader = create_data_loader(args.calib_loader_size, questions, args.image_folder, tokenizer, image_processor, model.config, args.conv_mode, args.output_file)

    return model, data_loader